import time
from agh_data import AGHData
from agh_obj import *
from agh_dispatch import *
import csv
import toolbox

class Simulation:
    """Tick-by-tick simulation that replays a dispatch plan against AGHData.

    Attributes:
        clock: Simulation clock; ``clock.time`` is compared with ``Operation.time``.
        cur_operation_index: Index of the next operation in ``cur_dispatch_plan``
            that has not been dispatched yet.
        cur_dispatch_method_index: Index into ``dispatch_methods`` of the active method.
        ga_dispatch_method_index: Index of the GA method (used by the GA parameter accessors).
        dispatch_methods: Available dispatch planners (FCFS, GA).
        dispatch_plans: Cached plan per dispatch method; an empty list means
            "not generated yet".
        cur_dispatch_plan: Plan currently being simulated.
        data: Shared simulation data (gates, flights, vehicles, costs).
    """

    def __init__(self) -> None:
        self.clock: Clock = Clock()
        self.cur_operation_index = 0
        self.cur_dispatch_method_index = 0
        self.ga_dispatch_method_index = 1
        self.dispatch_methods: List[DispatchMethod] = [FCFS(), GA()]
        self.dispatch_plans: List[List[Operation]] = [[] for _ in self.dispatch_methods]
        self.cur_dispatch_plan: List[Operation] = []
        self.data: AGHData = AGHData()

    def forward(self) -> Tuple[bool, str]:
        """Advance the simulation by one clock step.

        Returns:
            ``(True, "")`` if the step ran.
            ``(False, "调度模拟结束")`` ("simulation finished") once the last
            flight is done, or when there are no flights at all.
            ``(False, "调度方案错误")`` ("dispatch plan error") if dispatching failed.
        """
        flights = self.data.flight_data
        # Only the last flight's completion is checked to decide that the
        # simulation has ended (assumes it is the last to finish — TODO confirm).
        if not flights or flights[-1].is_done:
            return False, "调度模拟结束"

        self.clock.time_forward()
        self.deal_with_vehicle()
        self.working_gate_forward()
        if not self.dispatch_vehicle():
            return False, "调度方案错误"

        return True, ""

    def init_simulate(self, is_new_start: bool = False,
                      is_change_dispatch_method: bool = False) -> Tuple[bool, str]:
        """Reset the simulation, (re)building the current plan when requested.

        Args:
            is_new_start: When True, obtain a plan for the current method.
                With ``is_change_dispatch_method`` also True, an already cached
                plan is reused. Otherwise every cached plan is discarded and the
                current method's plan is regenerated.
            is_change_dispatch_method: See ``is_new_start``.

        Returns:
            ``(success, message)``. On success the clock is set to ``3 * 60``
            time units before the first operation. An empty plan yields
            ``(False, "调度方案为空")`` ("dispatch plan is empty").
        """
        self.cur_operation_index = 0
        self.data.refresh_data()
        success = True
        message = ""
        cur_dispatch_method = self.dispatch_methods[self.cur_dispatch_method_index]
        if is_new_start:
            if is_change_dispatch_method:
                # Reuse the cached plan for this method if one already exists.
                if len(self.cur_dispatch_plan) == 0:
                    success, self.cur_dispatch_plan, message = \
                        cur_dispatch_method.generate_dispath_plan(self.data)
            else:
                # A fresh start invalidates the plans cached for every method.
                self.dispatch_plans = [[] for _ in self.dispatch_methods]
                success, self.cur_dispatch_plan, message = \
                    cur_dispatch_method.generate_dispath_plan(self.data)

            self.dispatch_plans[self.cur_dispatch_method_index] = self.cur_dispatch_plan

        if not success:
            return success, message
        if not self.cur_dispatch_plan:
            # Previously this raised IndexError on cur_dispatch_plan[0].
            return False, "调度方案为空"
        # Start slightly before the first operation (presumably 3 minutes if
        # timestamps are in seconds — TODO confirm units).
        self.clock.set_time_stamp(self.cur_dispatch_plan[0].time - 3 * 60)
        return success, message

    def set_dispatch_method(self, index: int):
        """Select dispatch method ``index`` and switch to its cached plan."""
        self.cur_dispatch_method_index = index
        self.cur_dispatch_plan = self.dispatch_plans[index]

    def dispatch_vehicle(self) -> bool:
        """Dispatch every plan operation whose time equals the current clock time.

        Each operation assigns a vehicle to a gate. If the gate is idle, the gate
        starts serving the operation's flight. The vehicle is then sent on its way
        and the operation's cost is recorded in ``self.data``.

        Returns:
            False if a vehicle is not in Standby/Complete state when it is
            needed (the plan is inconsistent); True otherwise.
        """
        # NOTE(review): operations are matched with `==` on the clock time, so an
        # operation whose time the clock steps over is never dispatched — confirm
        # the clock visits every operation time.
        while self.cur_operation_index < len(self.cur_dispatch_plan) and \
                self.clock.time == self.cur_dispatch_plan[self.cur_operation_index].time:
            op = self.cur_dispatch_plan[self.cur_operation_index]
            gate = self.data.gate_data[op.gate_index]
            chosed_vehicle = self.data.vehicle_data[op.vehicle_type.value][op.vehicle_index]

            # Validate before touching any gate state so that a bad plan does
            # not leave the gate half-updated.
            if chosed_vehicle.state not in (VehicleState.Standby, VehicleState.Complete):
                mess = "0 cap " if chosed_vehicle.capcity == 0 else ""
                print(self.clock.get_full_time())
                print(mess, self.data.gate_data[chosed_vehicle.gate_index].name, gate,
                      chosed_vehicle.type, chosed_vehicle.index, chosed_vehicle.state,
                      chosed_vehicle.get_progress_value())
                print("车辆状态错误")
                return False

            if not gate.is_working:
                # Idle gate: start serving this operation's flight.
                gate.is_working = True
                gate.flight = self.data.flight_data[op.flight_index]
                gate.vehicles_work_done = [False for _ in VehicleType]
                # Promote vehicles that were queued while the previous flight was served.
                gate.vehicles = gate.next_vehicles
                gate.next_vehicles = [None for _ in VehicleType]
                # NOTE(review): wall-clock time, not simulation time — confirm intent.
                gate.start_serving_time = time.time()
                gate.vehicles[op.vehicle_type.value] = chosed_vehicle

                if gate.is_near_gate:
                    # Near gates have the shuttle-bus task marked done immediately.
                    gate.vehicles_work_done[VehicleType.ShuttleBus.value] = True
            else:
                # Busy gate: fill the slot, or queue the vehicle if the slot is taken.
                slot = chosed_vehicle.type.value
                if gate.vehicles[slot] is not None:
                    gate.next_vehicles[slot] = chosed_vehicle
                else:
                    gate.vehicles[slot] = chosed_vehicle

            chosed_vehicle.state = VehicleState.Coming
            chosed_vehicle.gate_index = op.gate_index
            chosed_vehicle.gate_name = gate.name
            chosed_vehicle.start_time = op.time
            chosed_vehicle.time_cost = ceil(op.distance / chosed_vehicle.speed)
            chosed_vehicle.end_time = op.time + chosed_vehicle.time_cost
            chosed_vehicle.back_path = self.data.get_path_park_to_gate(
                op.gate_index, chosed_vehicle.park_index, reversed=True)

            if op.is_source_a_park:
                chosed_vehicle.path = self.data.get_path_park_to_gate(
                    op.gate_index, chosed_vehicle.park_index)
            elif op.source_index != op.gate_index:
                chosed_vehicle.path = self.data.get_path_between_gates(
                    op.source_index, op.gate_index)
            else:
                # Already at the target gate: no travel path, the vehicle just waits.
                chosed_vehicle.state = VehicleState.Waiting

            unit_cost = Vehicle.property_of_type[op.vehicle_type].cost
            cost = op.distance * unit_cost
            if chosed_vehicle.capcity == 1:
                # Capacity-1 vehicles are also charged for the back path.
                cost += chosed_vehicle.back_path.length * unit_cost

            self.data.cost_data.append(cost)
            self.data.cost_sum += cost
            self.data.cost_avg = self.data.cost_sum / len(self.data.cost_data)
            self.data.cost_max = max(cost, self.data.cost_max)

            self.cur_operation_index += 1
        return True

    def working_gate_forward(self):
        """Advance every working gate and release those whose tasks are all done.

        A flight is marked arrived when its arrival timestamp equals the clock
        time. A gate with no unfinished task stops working and marks its flight done.
        """
        now = self.clock.time
        # Snapshot the working gates first; is_working is modified inside the loop.
        for gate in [g for g in self.data.gate_data if g.is_working]:
            if gate.flight.arrival_time_stamp == now:
                gate.flight.is_arrived = True

            gate.time_forward(now)
            if False not in gate.vehicles_work_done:
                gate.is_working = False
                gate.flight.is_done = True
                gate.flight = None

    def deal_with_vehicle(self):
        """Advance every vehicle that is Backing, Coming or Recovering."""
        active_states = (VehicleState.Backing, VehicleState.Coming, VehicleState.Recovering)
        now = self.clock.time
        for vehicles in self.data.vehicle_data:
            for vehicle in vehicles:
                if vehicle.state in active_states:
                    vehicle.time_forward(now)

    def get_operation_table_data(self, get_all: bool = True, num: int = -1) -> List[List[str]]:
        """Return display rows ``[vehicle, launch time, trace]`` for the plan.

        Args:
            get_all: When True, return rows for the whole current plan.
            num: Otherwise, the maximum number of upcoming (not yet dispatched)
                operations to return. A non-positive ``num`` yields ``[]``.
        """
        if get_all:
            return [self.get_operation_info(op) for op in self.cur_dispatch_plan]

        start = self.cur_operation_index
        end = min(start + num, len(self.cur_dispatch_plan))
        if end <= start:
            return []
        return [self.get_operation_info(op) for op in self.cur_dispatch_plan[start:end]]

    def get_operation_info(self, op: "Operation") -> List[str]:
        """Format ``op`` as ``[vehicle name+index, HH:MM:SS launch time, "src->dst"]``."""
        vehicle_name_and_index = VEHICLE_NAMES[op.vehicle_type.value] + str(op.vehicle_index)
        launch_time = Clock.get_time_str_with_HMS(op.time)
        # The source is either a park or another gate; the target is always a gate.
        sources = self.data.park_data if op.is_source_a_park else self.data.gate_data
        trace = sources[op.source_index].name + "->" + self.data.gate_data[op.gate_index].name
        return [vehicle_name_and_index, launch_time, trace]

    def get_dispatch_method_names(self) -> List[str]:
        """Return the ``name`` of every dispatch method, in index order."""
        return [d.name for d in self.dispatch_methods]

    def save_dipatch_plan(self, file_name: str):
        """Write the whole current plan to ``file_name`` as UTF-8 CSV rows."""
        # Explicit UTF-8: rows contain non-ASCII text, so the locale default
        # encoding is not safe. newline='' is what the csv module requires.
        with open(file_name, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f).writerows(self.get_operation_table_data())

    @staticmethod
    def _format_cost_legend(method_name: str, cost: list, cost_sum: float, cost_max: float) -> str:
        """Build one cost-chart legend entry; the average is 0 for an empty cost list."""
        cost_avg = cost_sum / len(cost) if cost else 0.0
        return f"{method_name} Sum:{cost_sum:.2f} Avg:{cost_avg:.2f} Max:{cost_max:.2f}"

    def show_cost_chart(self):
        """Open a chart comparing per-operation cost across all dispatch methods.

        Plans not generated yet are generated (and cached) first. If any
        generation fails, the method returns without showing the chart.
        """
        chart_data = []
        legends = []
        for index, plan in enumerate(self.dispatch_plans):
            method = self.dispatch_methods[index]
            if len(plan) == 0:
                # NOTE(review): generated from the current self.data, which may be
                # mid-simulation — confirm this is intended.
                success, ops, _ = method.generate_dispath_plan(self.data)
                if not success:
                    return
                self.dispatch_plans[index] = ops
            else:
                ops = plan
            cost, cost_sum, cost_max = get_cost(ops, self.data)
            chart_data.append(cost)
            legends.append(self._format_cost_legend(type(method).__name__, cost, cost_sum, cost_max))
        toolbox.open_chart_dialog(chart_data, legends, "调度代价图", "op_index", "cost")

    def set_ga_parms(self, pop_num: int, itera_times: int, cross_rate: float, mutation_rate: float):
        """Set the GA method's population size, iterations, crossover and mutation rates."""
        ga: GA = self.dispatch_methods[self.ga_dispatch_method_index]
        ga.pop_num = pop_num
        ga.itera_times = itera_times
        ga.cross_rate = cross_rate
        ga.mutation_rate = mutation_rate

    def get_ga_parms(self) -> Tuple[int, int, float, float]:
        """Return the GA method's ``(pop_num, itera_times, cross_rate, mutation_rate)``."""
        ga: GA = self.dispatch_methods[self.ga_dispatch_method_index]
        return ga.pop_num, ga.itera_times, ga.cross_rate, ga.mutation_rate
